{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "19cbad15-bdc7-4b74-8f6b-6e6da75fec35",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "## Interpretability - Text Explainers\n",
    "\n",
    "In this example, we use LIME and Kernel SHAP explainers to explain a text classification model.\n",
    "\n",
    "First we import the packages and define a UDF that extracts a single element from a vector, which we will need later. The plotting helper is defined further below, next to where it is used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "a2689fb5-2425-430d-8261-6e39598b6505",
     "showTitle": false,
     "title": ""
    },
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "from pyspark.sql.functions import *\n",
    "from pyspark.sql.types import *\n",
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.classification import LogisticRegression\n",
    "from pyspark.ml.functions import vector_to_array\n",
    "from synapse.ml.explainers import *\n",
    "from synapse.ml.featurize.text import TextFeaturizer\n",
    "from synapse.ml.core.platform import *\n",
    "\n",
    "def _vector_element(vector, index):\n",
    "    \"\"\"Return the element of `vector` at position `index` as a Python float.\"\"\"\n",
    "    return float(vector[index])\n",
    "\n",
    "\n",
    "# Spark UDF wrapper around the helper above; results are emitted as FloatType.\n",
    "vec_access = udf(_vector_element, FloatType())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "1f52a610-0695-48c2-9de9-e60f239dd5c7",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "Load the training data and convert the rating into a binary label (1 when the rating is above 3, 0 otherwise)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "a02806b1-e0ba-4b6f-93bf-5d3eb635e43e",
     "showTitle": false,
     "title": ""
    },
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "# Parquet file of book reviews hosted on public blob storage.\n",
    "reviews_path = \"wasbs://publicwasb@mmlspark.blob.core.windows.net/BookReviewsFromAmazon10K.parquet\"\n",
    "\n",
    "# Binary label: 1 when the rating is strictly above 3, 0 for any other non-null rating.\n",
    "binary_label = (col(\"rating\") > 3).cast(LongType())\n",
    "\n",
    "# Keep only the columns the model needs and cache, since the frame is reused below.\n",
    "reviews = spark.read.parquet(reviews_path)\n",
    "data = reviews.withColumn(\"label\", binary_label).select(\"label\", \"text\").cache()\n",
    "\n",
    "display(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "87a536da-b4b4-4c79-b6a3-5f7b1ad7428a",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "We train a text classification model, and randomly sample 10 rows to explain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "9a2fb867-194d-4660-b655-6373ec7272bf",
     "showTitle": false,
     "title": ""
    },
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "# Fixed seed so the train/test split, the fitted model and the rows chosen for\n",
    "# explanation stay the same across Restart & Run All.\n",
    "SEED = 42\n",
    "\n",
    "# 60% of the rows for training, 40% held out for evaluation and explanation.\n",
    "train, test = data.randomSplit([0.60, 0.40], seed=SEED)\n",
    "\n",
    "# Text featurization followed by a logistic regression classifier.\n",
    "pipeline = Pipeline(\n",
    "    stages=[\n",
    "        TextFeaturizer(\n",
    "            inputCol=\"text\",\n",
    "            outputCol=\"features\",\n",
    "            useStopWordsRemover=True,\n",
    "            useIDF=True,\n",
    "            minDocFreq=20,\n",
    "            numFeatures=1 << 16,  # 65536\n",
    "        ),\n",
    "        LogisticRegression(maxIter=100, regParam=0.005, labelCol=\"label\", featuresCol=\"features\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "model = pipeline.fit(train)\n",
    "\n",
    "prediction = model.transform(test)\n",
    "\n",
    "# Order by a seeded random key and keep 10 rows; both explainers below run on this sample.\n",
    "explain_instances = prediction.orderBy(rand(seed=SEED)).limit(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "3a9fbdc8-9660-4337-b3eb-7c717aabf0cc",
     "showTitle": false,
     "title": ""
    },
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "def plotConfusionMatrix(df, label, prediction, classLabels):\n",
    "    \"\"\"Draw a confusion matrix on a 4.5 x 4.5 inch figure and render it.\n",
    "\n",
    "    `df`, `label`, `prediction` and `classLabels` are forwarded unchanged to\n",
    "    `synapse.ml.plot.confusionMatrix`. The figure is rendered with `plt.show()`\n",
    "    when `running_on_synapse()` is true and with `display(...)` otherwise.\n",
    "    \"\"\"\n",
    "    import matplotlib.pyplot as plt\n",
    "    from synapse.ml.plot import confusionMatrix\n",
    "\n",
    "    figure = plt.figure(figsize=(4.5, 4.5))\n",
    "    confusionMatrix(df, label, prediction, classLabels)\n",
    "\n",
    "    if not running_on_synapse():\n",
    "        display(figure)\n",
    "        return\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plotConfusionMatrix(model.transform(test), \"label\", \"prediction\", [0, 1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "50c294f1-439a-455e-bff4-25c65822a575",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "First we use the LIME text explainer to explain the model's predicted probability for a given observation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "63623d84-8d6d-4f5b-8e2b-83e21866fb26",
     "showTitle": false,
     "title": ""
    },
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "# LIME explainer targeting class 1 of the model's \"probability\" output.\n",
    "lime = TextLIME(\n",
    "    model=model,\n",
    "    inputCol=\"text\",\n",
    "    tokensCol=\"tokens\",\n",
    "    outputCol=\"weights\",\n",
    "    targetCol=\"probability\",\n",
    "    targetClasses=[1],\n",
    "    samplingFraction=0.7,\n",
    "    numSamples=2000,\n",
    ")\n",
    "\n",
    "# Column expressions that flatten the explainer output for display.\n",
    "lime_probability = vec_access(\"probability\", lit(1))  # element 1 of \"probability\"\n",
    "lime_weights = vector_to_array(col(\"weights\").getItem(0))  # entry 0 of \"weights\", as an array\n",
    "lime_r2 = vec_access(\"r2\", lit(0))  # element 0 of \"r2\"\n",
    "\n",
    "lime_explained = lime.transform(explain_instances).select(\"tokens\", \"weights\", \"r2\", \"probability\", \"text\")\n",
    "\n",
    "# Columns are replaced in this order so that the zip pairs tokens with the flattened weights.\n",
    "lime_results = (\n",
    "    lime_explained.withColumn(\"probability\", lime_probability)\n",
    "    .withColumn(\"weights\", lime_weights)\n",
    "    .withColumn(\"r2\", lime_r2)\n",
    "    .withColumn(\"tokens_weights\", arrays_zip(\"tokens\", \"weights\"))\n",
    ")\n",
    "\n",
    "lime_view = lime_results.select(\"probability\", \"r2\", \"tokens_weights\", \"text\")\n",
    "display(lime_view.orderBy(col(\"probability\").desc()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "de0ff3a3-fc93-4769-8b98-2fab81467fcc",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "Then we use the Kernel SHAP text explainer to explain the model's predicted probability for a given observation.\n",
    "\n",
    "> Notice that we drop the base value from the SHAP output before displaying the SHAP values. The base value is the model output for an empty string."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "inputWidgets": {},
     "nuid": "9d3fd01d-f140-465e-ae53-d3b25f246e4d",
     "showTitle": false,
     "title": ""
    },
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "# Kernel SHAP explainer targeting class 1 of the model's \"probability\" output.\n",
    "shap = TextSHAP(\n",
    "    model=model,\n",
    "    inputCol=\"text\",\n",
    "    tokensCol=\"tokens\",\n",
    "    outputCol=\"shaps\",\n",
    "    targetCol=\"probability\",\n",
    "    targetClasses=[1],\n",
    "    numSamples=5000,\n",
    ")\n",
    "\n",
    "# Column expressions that flatten the explainer output for display.\n",
    "shap_probability = vec_access(\"probability\", lit(1))  # element 1 of \"probability\"\n",
    "shap_values = vector_to_array(col(\"shaps\").getItem(0))  # entry 0 of \"shaps\", as an array\n",
    "# Slice from position 2 onwards, dropping the leading element (the SHAP base value).\n",
    "shap_values_without_base = slice(col(\"shaps\"), lit(2), size(col(\"shaps\")))\n",
    "shap_r2 = vec_access(\"r2\", lit(0))  # element 0 of \"r2\"\n",
    "\n",
    "shap_explained = shap.transform(explain_instances).select(\"tokens\", \"shaps\", \"r2\", \"probability\", \"text\")\n",
    "\n",
    "# Columns are replaced in this order: flatten \"shaps\" first, then strip the base value, then zip.\n",
    "shap_results = (\n",
    "    shap_explained.withColumn(\"probability\", shap_probability)\n",
    "    .withColumn(\"shaps\", shap_values)\n",
    "    .withColumn(\"shaps\", shap_values_without_base)\n",
    "    .withColumn(\"r2\", shap_r2)\n",
    "    .withColumn(\"tokens_shaps\", arrays_zip(\"tokens\", \"shaps\"))\n",
    ")\n",
    "\n",
    "shap_view = shap_results.select(\"probability\", \"r2\", \"tokens_shaps\", \"text\")\n",
    "display(shap_view.orderBy(col(\"probability\").desc()))"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 2
   },
   "notebookName": "Interpretability - Text Explainers",
   "notebookOrigID": 913802417841163,
   "widgets": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": ""
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
